# Post-Training

Post-training in Llama Stack lets you fine-tune models through a common `post_training` API backed by several providers. This section describes each available provider, its configuration, and how to launch and monitor a fine-tuning job.

## Overview

Llama Stack provides three post-training providers:

- **HuggingFace SFTTrainer** (`inline::huggingface`) - Fine-tuning using the HuggingFace ecosystem
- **TorchTune** (`inline::torchtune`) - Fine-tuning using Meta's TorchTune framework
- **NVIDIA** (`remote::nvidia`) - Fine-tuning using NVIDIA's NeMo Customizer service

## HuggingFace SFTTrainer

[HuggingFace SFTTrainer](https://huggingface.co/docs/trl/en/sft_trainer) is an inline post-training provider for Llama Stack, so training runs on the machine hosting the Llama Stack server. It runs supervised fine-tuning (SFT) on a wide range of HuggingFace models and datasets.

### Features

- Simple access through the `post_training` API
- Fully integrated with Llama Stack
- GPU, CPU, and MPS (macOS Metal Performance Shaders) support

### Configuration

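| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `device` | `str` | No | cuda | Device to train on: `cuda`, `cpu`, or `mps`. |
| `distributed_backend` | `Literal['fsdp', 'deepspeed'] \| None` | No |  | Distributed training backend; leave unset for single-device training. |
| `checkpoint_format` | `Literal['full_state', 'huggingface']` | No | huggingface | Format of saved checkpoints. |
| `chat_template` | `str` | No |  | Template used to format training examples into text. |
| `model_specific_config` | `dict` | No | `{'trust_remote_code': True, 'attn_implementation': 'sdpa'}` | Model loading options, such as `trust_remote_code` and the attention implementation. |
| `max_seq_length` | `int` | No | 2048 | Maximum sequence length, in tokens. |
| `gradient_checkpointing` | `bool` | No | False | Recompute activations during the backward pass to reduce memory use, at the cost of extra compute. |
| `save_total_limit` | `int` | No | 3 | Maximum number of checkpoints to keep. |
| `logging_steps` | `int` | No | 10 | Number of training steps between log entries. |
| `warmup_ratio` | `float` | No | 0.1 | Fraction of training steps used for learning-rate warmup. |
| `weight_decay` | `float` | No | 0.01 | Weight decay applied by the optimizer. |
| `dataloader_num_workers` | `int` | No | 4 | Number of worker processes for data loading. |
| `dataloader_pin_memory` | `bool` | No | True | Whether the data loader pins host memory for faster transfer to the GPU. |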
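
The defaults in the table can be read as a base configuration that any keys you set in your run configuration override. The sketch below mirrors that with a plain dictionary and checks the two `Literal` fields against their allowed values. It is illustrative only: `HF_DEFAULTS` and `build_hf_config` are hypothetical names, not part of Llama Stack, and the provider's real validation lives inside Llama Stack itself.

```python
import copy
from typing import Any

# Defaults copied from the configuration table above (chat_template is omitted).
HF_DEFAULTS: dict[str, Any] = {
    "device": "cuda",
    "distributed_backend": None,
    "checkpoint_format": "huggingface",
    "model_specific_config": {"trust_remote_code": True, "attn_implementation": "sdpa"},
    "max_seq_length": 2048,
    "gradient_checkpointing": False,
    "save_total_limit": 3,
    "logging_steps": 10,
    "warmup_ratio": 0.1,
    "weight_decay": 0.01,
    "dataloader_num_workers": 4,
    "dataloader_pin_memory": True,
}

# Allowed values for the Literal-typed fields; None leaves distributed_backend unset.
ALLOWED: dict[str, set] = {
    "distributed_backend": {None, "fsdp", "deepspeed"},
    "checkpoint_format": {"full_state", "huggingface"},
}


def build_hf_config(overrides: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical helper: apply overrides on top of the defaults and check Literal fields."""
    unknown = set(overrides) - set(HF_DEFAULTS) - {"chat_template"}
    if unknown:
        raise KeyError(f"unknown fields: {sorted(unknown)}")
    # Deep-copy so callers cannot mutate the shared nested defaults.
    config = copy.deepcopy(HF_DEFAULTS)
    config.update(overrides)
    for field, allowed in ALLOWED.items():
        if config[field] not in allowed:
            raise ValueError(f"{field}={config[field]!r} is not one of {allowed}")
    return config


# The sample configuration on this page, expressed as overrides.
cfg = build_hf_config({"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu"})
print(cfg["device"], cfg["max_seq_length"])  # cpu 2048
```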

### Sample Configuration

```yaml
checkpoint_format: huggingface
distributed_backend: null
device: cpu
```
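
The sample above pins `device: cpu`. If you generate the configuration from Python, a sketch like the following chooses `cuda`, then `mps`, then `cpu`, depending on what the local PyTorch install reports. Run it on the machine that hosts the Llama Stack server, since that is where an inline provider trains. `pick_device` is a hypothetical helper, and the sketch falls back to `cpu` when PyTorch is not installed.

```python
def pick_device() -> str:
    """Return 'cuda', 'mps', or 'cpu' based on what PyTorch reports on this machine."""
    try:
        import torch
    except ImportError:
        # Without PyTorch we cannot probe accelerators, so default to CPU.
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps is missing on older PyTorch builds, so guard the lookup.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


sample_config = {
    "checkpoint_format": "huggingface",
    "distributed_backend": None,  # single-device training
    "device": pick_device(),
}
print(sample_config)
```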

### Setup

The HuggingFace trainer is included in the `starter` distribution. Install the distribution's dependencies and start the server:

```bash
llama stack list-deps starter | xargs -L1 uv pip install
llama stack run starter
```

### Usage Example

```python
import time
import uuid

from llama_stack_client import LlamaStackClient
from llama_stack_client.types import (
    algorithm_config_param,
    post_training_supervised_fine_tune_params,
)

# Connect to a Llama Stack server running locally on port 8321.
client = LlamaStackClient(base_url="http://localhost:8321")

# Register an example dataset from the HuggingFace Hub
client.datasets.register(
    purpose="post-training/messages",
    source={
        "type": "uri",
        "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
    },
    dataset_id="simpleqa",
)

training_config = post_training_supervised_fine_tune_params.TrainingConfig(
    data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
        batch_size=32,
        data_format="instruct",
        dataset_id="simpleqa",
        shuffle=True,
    ),
    gradient_accumulation_steps=1,
    max_steps_per_epoch=0,
    max_validation_steps=1,
    n_epochs=4,
)

algorithm_config = algorithm_config_param.LoraFinetuningConfig(
    alpha=1,
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    lora_attn_modules=["q_proj"],
    rank=1,
    type="LoRA",
)

job_uuid = f"test-job-{uuid.uuid4()}"

# Example Model
training_model = "ibm-granite/granite-3.3-8b-instruct"

start_time = time.time()
response = client.post_training.supervised_fine_tune(
    job_uuid=job_uuid,
    logger_config={},
    model=training_model,
    hyperparam_search_config={},
    training_config=training_config,
    algorithm_config=algorithm_config,
    checkpoint_dir="output",
)
print("Job:", job_uuid)

# Poll until the job reaches a terminal state. Breaking only on "completed"
# would spin forever if the job fails or is cancelled.
while True:
    status = client.post_training.job.status(job_uuid=job_uuid)
    if not status:
        print("Job not found")
        break

    print(status)
    if status.status in ("completed", "failed", "cancelled"):
        break

    print("Waiting for job to complete...")
    time.sleep(5)

elapsed = time.time() - start_time
print(f"Job finished in {elapsed:.0f} seconds")

print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```
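
The polling loop in the example has no upper bound on how long it waits. A reusable variant with a timeout is sketched below. It takes the status-fetching function as a parameter so it can be exercised without a server. `wait_for_job` is a hypothetical helper, and the terminal status names `completed`, `failed`, and `cancelled` (plus `scheduled` and `in_progress` in the demo) are assumptions about the job status values rather than a documented contract.

```python
import time
from typing import Callable, Optional

# Assumed terminal job states.
TERMINAL = {"completed", "failed", "cancelled"}


def wait_for_job(
    get_status: Callable[[], Optional[str]],
    timeout_s: float = 3600.0,
    poll_s: float = 5.0,
) -> str:
    """Poll get_status() until it returns a terminal status, then return that status.

    get_status returns the job's status string, or None if the job is unknown.
    Raises LookupError for an unknown job and TimeoutError once timeout_s has elapsed.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        status = get_status()
        if status is None:
            raise LookupError("job not found")
        if status in TERMINAL:
            return status
        if time.monotonic() >= deadline:
            raise TimeoutError(f"job still {status!r} after {timeout_s} s")
        time.sleep(poll_s)


# With a live server, the status source could be:
# wait_for_job(lambda: getattr(client.post_training.job.status(job_uuid=job_uuid), "status", None))

# Demo with a fake status source instead of a live server.
_fake = iter(["scheduled", "in_progress", "completed"])
result = wait_for_job(lambda: next(_fake), poll_s=0.0)
print(result)  # completed
```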

## TorchTune

[TorchTune](https://github.com/pytorch/torchtune) is an inline post-training provider for Llama Stack. It provides a simple, efficient way to fine-tune language models with PyTorch.

### Features

- Simple access through the `post_training` API
- Fully integrated with Llama Stack
- Single-device GPU training
- LoRA support

### Configuration

| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `torch_seed` | `int \| None` | No |  | Random seed for PyTorch, for reproducible runs. |
| `checkpoint_format` | `Literal['meta', 'huggingface']` | No | meta | Format of model checkpoints. |

### Sample Configuration

```yaml
checkpoint_format: meta
```

### Setup

To use the TorchTune trainer, register it as a `post_training` provider in your stack's run configuration:

```yaml
providers:
  post_training:
    - provider_id: torchtune
      provider_type: inline::torchtune
      config: {}
```

An empty `config` uses the defaults listed above. You can then build and run your own stack with this provider.

### Usage Example

```python
import time
import uuid

from llama_stack_client import LlamaStackClient
from llama_stack_client.types import (
    algorithm_config_param,
    post_training_supervised_fine_tune_params,
)

# Connect to a Llama Stack server running locally on port 8321.
client = LlamaStackClient(base_url="http://localhost:8321")

# Register an example dataset from the HuggingFace Hub
client.datasets.register(
    purpose="post-training/messages",
    source={
        "type": "uri",
        "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
    },
    dataset_id="simpleqa",
)

training_config = post_training_supervised_fine_tune_params.TrainingConfig(
    data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
        batch_size=32,
        data_format="instruct",
        dataset_id="simpleqa",
        shuffle=True,
    ),
    gradient_accumulation_steps=1,
    max_steps_per_epoch=0,
    max_validation_steps=1,
    n_epochs=4,
)

algorithm_config = algorithm_config_param.LoraFinetuningConfig(
    alpha=1,
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    lora_attn_modules=["q_proj"],
    rank=1,
    type="LoRA",
)

job_uuid = f"test-job-{uuid.uuid4()}"

# Example Model
training_model = "meta-llama/Llama-2-7b-hf"

start_time = time.time()
response = client.post_training.supervised_fine_tune(
    job_uuid=job_uuid,
    logger_config={},
    model=training_model,
    hyperparam_search_config={},
    training_config=training_config,
    algorithm_config=algorithm_config,
    checkpoint_dir="output",
)
print("Job:", job_uuid)

# Poll until the job reaches a terminal state. Breaking only on "completed"
# would spin forever if the job fails or is cancelled.
while True:
    status = client.post_training.job.status(job_uuid=job_uuid)
    if not status:
        print("Job not found")
        break

    print(status)
    if status.status in ("completed", "failed", "cancelled"):
        break

    print("Waiting for job to complete...")
    time.sleep(5)

elapsed = time.time() - start_time
print(f"Job finished in {elapsed:.0f} seconds")

print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```
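
The LoRA settings in these examples (`rank=1`, `alpha=1`) are deliberately tiny. In standard LoRA, an adapted weight matrix of shape (d_out, d_in) gets two trainable factors, B of shape (d_out, r) and A of shape (r, d_in), so it adds r * (d_in + d_out) parameters, and the update B @ A is scaled by alpha / r. The sketch below applies that formula to the `q_proj` projections of Llama-2-7B (hidden size 4096, 32 decoder layers). It ignores the MLP adapters that `apply_lora_to_mlp=True` also requests, so treat the result as a lower bound; `lora_params` and `lora_scale` are hypothetical helpers.

```python
def lora_params(rank: int, shapes: list[tuple[int, int]], n_layers: int) -> int:
    """Trainable LoRA parameters: rank * (d_in + d_out) per adapted matrix, per layer."""
    per_layer = sum(rank * (d_in + d_out) for d_in, d_out in shapes)
    return per_layer * n_layers


def lora_scale(alpha: float, rank: int) -> float:
    """Scaling applied to the low-rank update B @ A in standard LoRA."""
    return alpha / rank


# Llama-2-7B: hidden size 4096, 32 decoder layers; q_proj maps 4096 -> 4096.
q_only = lora_params(rank=1, shapes=[(4096, 4096)], n_layers=32)
print(q_only)                       # 262144
print(lora_scale(alpha=1, rank=1))  # 1.0
```

Raising `rank` grows the adapter linearly, which is why small ranks keep LoRA runs cheap compared with full fine-tuning.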

## NVIDIA

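The NVIDIA provider (`remote::nvidia`) is a remote post-training provider: it submits fine-tuning jobs to NVIDIA's NeMo Customizer service instead of training on the Llama Stack server.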

### Configuration

| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `api_key` | `str \| None` | No |  | The NVIDIA API key. |
| `dataset_namespace` | `str \| None` | No | default | The NVIDIA dataset namespace. |
| `project_id` | `str \| None` | No | test-example-model@v1 | The NVIDIA project ID. |
| `customizer_url` | `str \| None` | No |  | Base URL for the NeMo Customizer API. |
| `timeout` | `int` | No | 300 | Timeout for the NVIDIA Post Training API. |
| `max_retries` | `int` | No | 3 | Maximum number of retries for the NVIDIA Post Training API. |
| `output_model_dir` | `str` | No | test-example-model@v1 | Directory to save the output model. |

### Sample Configuration

```yaml
api_key: ${env.NVIDIA_API_KEY:=}
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:=default}
project_id: ${env.NVIDIA_PROJECT_ID:=test-project}
customizer_url: ${env.NVIDIA_CUSTOMIZER_URL:=http://nemo.test}
```
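
The sample configuration reads its values from environment variables through `${env.NAME:=default}` placeholders. The function below is a simplified approximation of that substitution, for illustration only: it assumes the default applies when the variable is unset, and it does not reproduce Llama Stack's own resolver or any other placeholder forms. `resolve_env_placeholders` is a hypothetical name.

```python
import os
import re
from typing import Mapping, Optional

# Matches ${env.NAME:=default}; the default may be empty, as for NVIDIA_API_KEY above.
_PLACEHOLDER = re.compile(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*):=([^}]*)\}")


def resolve_env_placeholders(value: str, environ: Optional[Mapping[str, str]] = None) -> str:
    """Replace each ${env.NAME:=default} with environ[NAME], or with default if NAME is unset."""
    env = os.environ if environ is None else environ
    return _PLACEHOLDER.sub(lambda m: env.get(m.group(1), m.group(2)), value)


print(resolve_env_placeholders("${env.NVIDIA_DATASET_NAMESPACE:=default}", environ={}))
# default
print(
    resolve_env_placeholders(
        "${env.NVIDIA_PROJECT_ID:=test-project}",
        environ={"NVIDIA_PROJECT_ID": "my-project"},
    )
)
# my-project
print(repr(resolve_env_placeholders("${env.NVIDIA_API_KEY:=}", environ={})))
# ''
```

Under this reading, leaving `NVIDIA_API_KEY` unset produces an empty `api_key`, so export it before starting the stack.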

## Best Practices

- **Choose the right provider**: Use HuggingFace for the broadest model coverage and for CPU or MPS training, TorchTune for Meta's Llama models, or NVIDIA to run jobs on NeMo Customizer
- **Configure hardware appropriately**: Set `device` (and `distributed_backend`, if used) to match the hardware on the machine running the Llama Stack server (CPU, GPU, MPS)
- **Monitor jobs**: Poll job status until the job completes, fails, or is cancelled, and handle each outcome
- **Use appropriate datasets**: Make sure the dataset's `purpose` and the `data_format` in your training config match what your chosen provider expects

## Next Steps

- Check out the [Building Applications - Fine-tuning](../building_applications/index.mdx) guide for application-level examples
- See the [Providers](../providers/post_training/index.mdx) section for detailed provider documentation
- Review the [API Reference](../advanced_apis/post_training.mdx) for complete API documentation
